# -*- coding: utf-8 -*-

import numpy as np
import scipy.io as sio
import os
import matplotlib
import xlrd
from matplotlib import pyplot as plt
import time
from scipy.optimize import curve_fit
import socket
import HotAtomDist_3D
# Switch the Windows console code page to UTF-8 (`chcp 65001`) so printed text is not garbled.
os.system('chcp 65001')
# The synced data folder lives at a different root on each machine; select it by host name.
Host_Name = socket.gethostname()
if Host_Name == 'VictorChengPC':
    Computer_Name = 'X:\\Work\\mail.ustc.edu.cn\\Rb87BEC - Files\\'
elif Host_Name == 'calcenter2':
    Computer_Name = 'D:\\vc\\mail.ustc.edu.cn\\Rb87BEC - Files\\'
else:
    # Fail here with a clear message instead of a NameError on Computer_Name much later.
    raise RuntimeError('Unknown host %r: add its data root to the Computer_Name table.' % Host_Name)
# Large default font for all figures produced by this script.
font = {'size': 26}
matplotlib.rc('font', **font)


# %%
def poly5(x, a, b, c, d, e, f):
    """Return the quintic a + b*x + c*x**2 + d*x**3 + e*x**4 + f*x**5.

    Coefficients are in ascending powers of x, so the function can be used directly as a
    ``curve_fit`` model. x may be a scalar or a NumPy array.
    """
    total = a
    # Terms are added left to right in the same order as the expanded expression.
    for power, coefficient in enumerate((b, c, d, e, f), start=1):
        total = total + coefficient * x ** power
    return total


def poly5p(x, b, d, f):
    """Return the odd quintic b*x + d*x**3 + f*x**5 (no even-power terms)."""
    linear_term = b * x
    cubic_term = d * x ** 3
    quintic_term = f * x ** 5
    return linear_term + cubic_term + quintic_term


def Gauss(x, a, b, c, d):
    """Return a Gaussian of amplitude a, centre b and width c, on a constant offset d.

    Evaluates a * exp(-(x - b)**2 / (2 * c**2)) + d for scalar or array x.
    """
    shifted_sq = (x - b) ** 2
    two_width_sq = 2 * c ** 2
    return a * np.exp(-shifted_sq / two_width_sq) + d


def poly4(x, a, b, c, d, e):
    """Return the quartic a + b*x + c*x**2 + d*x**3 + e*x**4 (ascending coefficients)."""
    total = a
    for power, coefficient in enumerate((b, c, d, e), start=1):
        total = total + coefficient * x ** power
    return total


def cos(x, a, b, c, d):
    """Return a cosine of amplitude a, angular frequency b and phase c, offset by d."""
    phase = b * x + c
    return a * np.cos(phase) + d


def poly6(x, a, b, c, d, e, f, g):
    """Return the sextic a + b*x + ... + g*x**6 (coefficients in ascending powers of x).

    Used as a ``curve_fit`` model; the fitted coefficients are later reversed for ``np.roots``.
    """
    total = a
    for power, coefficient in enumerate((b, c, d, e, f, g), start=1):
        total = total + coefficient * x ** power
    return total


def poly6p(x, a, c, e, g):
    """Return the even sextic a + c*x**2 + e*x**4 + g*x**6 (no odd-power terms)."""
    total = a
    for power, coefficient in ((2, c), (4, e), (6, g)):
        total = total + coefficient * x ** power
    return total


# %% Setup
Start_Time = time.time()
# Open one precomputed band-structure file only to read the (ku, ky, kv) momentum grids.
# The context manager closes the .npz file handle once the grids have been read.
with np.load(
        Computer_Name + "Rb87Data\\3DSOCbandData\\190827\\"
        "3x3x3;4x41x41x41;V_xz=1.770E_r,V_y=-3.540E_r,Oxy=1.022E_r,Ozy=-1.022E_r,T=1.000e-07K,n=3.000e+18m-3"
        "del_psi=1.000pi,mz=%.3fErno_states_v1.npz" % 0.0) as data:
    ku_grid = data['ku_grid']
    ky_grid = data['ky_grid']
    kv_grid = data['kv_grid']
# The ku, ky and kv grids are assumed to have the same length (the arrays below use gridN for all three).
gridN = len(ku_grid)
num_of_band = 4
mz_list = np.linspace(-0.4, 0.4, 81)
# Per-mz result stacks; the trailing axis indexes mz_list.
energies_all_mz = np.zeros((num_of_band, gridN, gridN, gridN, len(mz_list)))
ST_all_mz = np.zeros((gridN, gridN, gridN, len(mz_list)))
# ST_yinted_all_mz = np.zeros((gridN, gridN, len(mz_list)))
ESigmaZ_all_mz = np.zeros((num_of_band, gridN, gridN, gridN, len(mz_list)))
Atom_Temperature = 150e-9

# %% Get simed data
# mz value (in E_r) at which the S(ky) slice is taken; must be a point of mz_list.
mz_simed = 0.25
# data_simed = np.load(Computer_Name + "Rb87Data\\3DSOCbandData\\190827\\"
#                                      "3x3x3;4x101x101x101;V_xz=1.770E_r,V_y=-3.540E_r,Oxy=1.022E_r,Ozy=-1.022E_r,"
#                                      "T=1.000e-07K,n=3.000e+18m-3del_psi=1.000pi,"
#                                      "mz=%.3fErno_states_v1.npz" % mz_simed,
#                      allow_pickle=True)
# Weyl_Points = data_simed['Weyl_Points']

# %% Get data
# For every mz, load the temperature-specific cache file if it exists. Otherwise, start from
# the T = 100 nK base file, compute the thermal occupation at Atom_Temperature with
# HotAtomDist_3D, and write a new cache file.
Data_Dir = Computer_Name + "Rb87Data\\3DSOCbandData\\190827\\"
# One template serves the cache files and the base files; the base files are the T=1.000e-07K case.
Band_File_Template = ("3x3x3;4x41x41x41;V_xz=1.770E_r,V_y=-3.540E_r,Oxy=1.022E_r,Ozy=-1.022E_r,"
                      "T=%.3eK,n=3.000e+18m-3del_psi=1.000pi,"
                      "mz=%.3fErno_states_v1.npz")
Base_Temperature = 1e-7
Atom_Density = 3e18  # matches the n=3.000e+18m-3 tag in the file names
for mz_index, mz in enumerate(mz_list):
    Load_Name = Data_Dir + Band_File_Template % (Atom_Temperature, mz)
    if os.path.isfile(Load_Name):
        with np.load(Load_Name) as data:
            energies = data['energies']
            ESigmaZ = data['ESigmaZ']
            ST_all_mz[..., mz_index] = data['Spin_Texture']
            # ST_yinted_all_mz[..., mz_index] = data['Spin_Texture_yinted']
        ESigmaZ_all_mz[..., mz_index] = ESigmaZ
        energies_all_mz[..., mz_index] = energies
    else:
        print(":(")
        with np.load(Data_Dir + Band_File_Template % (Base_Temperature, mz)) as data:
            # Each NpzFile lookup reads a fresh array, so Data_dict does not alias energies/ESigmaZ.
            energies = data['energies']
            ESigmaZ = data['ESigmaZ']
            Data_dict = dict(data)
        (mu, ns, nsp) = HotAtomDist_3D.ChemicalPotential(energies, Atom_Temperature, Atom_Density)
        # Pause for manual inspection if the total occupation misses the target density by >1 %.
        if np.abs(np.sum(ns) - Atom_Density) / Atom_Density > 0.01:
            input('whut')
        ESigmaZ_all_mz[..., mz_index] = ESigmaZ
        # Occupation-weighted spin, summed over the band axis (axis 0).
        Spin_Texture = np.sum(nsp * ESigmaZ, 0)
        energies_all_mz[..., mz_index] = energies
        ST_all_mz[..., mz_index] = Spin_Texture
        # Occupation-weighted spin averaged over axes 0 (band) and 2.
        Spin_Texture_yinted = np.sum(ESigmaZ * ns, axis=(0, 2)) / np.sum(ns, axis=(0, 2))
        # ST_yinted_all_mz[..., mz_index] = Spin_Texture_yinted
        Data_dict['Atom_Temperature'] = Atom_Temperature
        Data_dict['Spin_Texture'] = Spin_Texture
        Data_dict['Spin_Texture_yinted'] = Spin_Texture_yinted
        Data_dict['ns'] = ns
        Data_dict['Spin_Texture_y05'] = 0
        # Save under the temperature-specific name so the next run takes the fast branch.
        np.savez(Load_Name, **Data_dict)

# %% Fit for every kukv
# Bad_List_yfixed = []
# Bad_List_mzfixed = []
# Parameters_yfixed = []
# Parameters_mzfixed = []
# (kus, kvs) = np.meshgrid(range(41), range(41), indexing='ij')
# kus = np.reshape(kus, (np.size(kus), 1), order='F')
# kvs = np.reshape(kvs, (np.size(kvs), 1), order='F')
#
# for ii in range(gridN * gridN):
#     kukv = (kus[ii][0], kvs[ii][0])
#     ESZ_Show_0 = ESigmaZ_all_mz[0, kukv[0], :, kukv[1], :]
#     ST_Show = ST_all_mz[kukv[0], :, kukv[1], :]
#     mz_show = 0.2
#     ST_Show_mzfixed = ST_Show[:, np.where(np.abs(mz_list - mz_show) < 1e-10)[0][0]]
#     ky_show = 0.5
#     ST_Show_yfixed = ST_Show[np.where(np.abs(ky_grid - ky_show) < 1e-10)[0][0], :]
#     ESZ_Show_yfixed = ESZ_Show_0[np.where(np.abs(ky_grid - ky_show) < 1e-10)[0][0], :]
#     ST_yinted_Show = ST_yinted_all_mz[kukv[0], kukv[1], :]
#
#     # %%
#     '''
#     plt.plot(mz_list, ST_yinted_Show, '-*')
#     plt.title('intged, ku=%.3fpi, kv=%.3fpi'%(ku_grid[kukv[0]], kv_grid[kukv[1]]))
#     plt.xlabel('mz/Er')
#     plt.show()
#
#     plt.plot(mz_list, ST_Show_yfixed, '-*')
#     plt.title('Spin Texture, ky=pi/2, ku=%.3fpi, kv=%.3fpi'%(ku_grid[kukv[0]], kv_grid[kukv[1]]))
#     plt.xlabel('mz/Er')
#     plt.show()
#
#     plt.plot(mz_list, ESZ_Show_yfixed, '-*')
#     plt.title('ESigmaZ, n=0, ky=pi/2, ku=%.3fpi, kv=%.3fpi'%(ku_grid[kukv[0]], kv_grid[kukv[1]]))
#     plt.xlabel('mz/Er')
#     plt.show()
#
#     plt.plot(np.cos(ky_grid), ST_Show_mzfixed, '-*')
#     plt.title('mz=0.2Er, ku=%.3fpi, kv=%.3fpi'%(ku_grid[kukv[0]], kv_grid[kukv[1]]))
#     plt.xlabel('ky/pi')
#     plt.show()
#     '''
#     # %% Fit S(mz)
#     Fit_Func_yfixed = poly5
#     (par_yfixed, pcov_yfixed) = curve_fit(Fit_Func_yfixed, mz_list, ST_Show_yfixed)
#     Parameters_yfixed.append(par_yfixed)
#     perr_yfixed = np.sqrt(np.diag(pcov_yfixed))
#     ST_Show_yfixed_fit = Fit_Func_yfixed(mz_list, *par_yfixed)
#     ybar_yfixed = np.sum(ST_Show_yfixed) / len(ST_Show_yfixed)
#     SS_tot_yfixed = np.sum((ST_Show_yfixed - ybar_yfixed) ** 2)
#     SS_res_yfixed = np.sum((ST_Show_yfixed - ST_Show_yfixed_fit) ** 2)
#     R2_yfixed = 1 - SS_res_yfixed / SS_tot_yfixed
#     #    if np.any(np.logical_and(np.abs(perr_yfixed/par_yfixed) > 0.1, np.abs(par_yfixed) > 1e-5)):
#     if R2_yfixed < 0.95:  # or np.abs(par_yfixed[-1]) > 3 or 1:
#         if np.mod(ii, 16) == 0:
#             if np.mod(ii // 16, 20) == 0:
#                 plt.show()
#                 fig = plt.figure(ii // 16 // 20, figsize=(26, 13))
#             plt.subplot(4, 5, np.mod(ii // 16, 20) + 1)
#             plt.plot(mz_list, ST_Show_yfixed, '*', mz_list, ST_Show_yfixed_fit)
#             #        plt.title('Spin Texture, ky=pi/2, ku=%.3fpi, kv=%.3fpi'%(ku_grid[kukv[0]], kv_grid[kukv[1]]))
#             plt.title(
#                 'ku=%.2f,kv=%.2f,%.3f,%.3f,%.3f,%.3f,%.3f,%.3f' % (ku_grid[kukv[0]], kv_grid[kukv[1]], *par_yfixed))
#             plt.xlabel('mz/Er')
#             plt.subplots_adjust(top=0.971, bottom=0.045, left=0.022, right=0.992, hspace=0.335, wspace=0.139)
#             #        plt.show()
#             Bad_List_yfixed.append(kukv)
#
#     # %% Fit S(ky)
#     Fit_Func_mzfixed = poly6
#     (par_mzfixed, pcov_mzfixed) = curve_fit(Fit_Func_mzfixed, ky_grid, ST_Show_mzfixed)
#     perr_mzfixed = np.sqrt(np.diag(pcov_mzfixed))
#     ST_Show_mzfixed_fit = Fit_Func_mzfixed(ky_grid, *par_mzfixed)
#     ybar_mzfixed = np.sum(ST_Show_mzfixed) / len(ST_Show_mzfixed)
#     SS_tot_mzfixed = np.sum((ST_Show_mzfixed - ybar_mzfixed) ** 2)
#     SS_res_mzfixed = np.sum((ST_Show_mzfixed - ST_Show_mzfixed_fit) ** 2)
#     R2_mzfixed = 1 - SS_res_mzfixed / SS_tot_mzfixed
#     #    if np.any(np.logical_and(np.abs(perr_mzfixed/par_mzfixed) > 0.1, np.abs(par_mzfixed) > 1e-5)):
#     if R2_mzfixed < 0.95:
#         plt.plot(ky_grid, ST_Show_mzfixed, '*', ky_grid, ST_Show_mzfixed_fit)
#         plt.title('mz=0.2Er, ku=%.3fpi, kv=%.3fpi' % (ku_grid[kukv[0]], kv_grid[kukv[1]]))
#         plt.xlabel('ky/pi')
#         plt.show()
#         Bad_List_mzfixed.append(kukv)
#     Parameters_mzfixed.append(par_mzfixed)
#
# # %% Check parameters for all kukv
# Parameters_yfixed = np.array(Parameters_yfixed)
# Parameters_yfixed = np.reshape(Parameters_yfixed, (gridN, gridN, np.shape(Parameters_yfixed)[-1]))
# Parameters_mzfixed = np.array(Parameters_mzfixed)
# Parameters_mzfixed = np.reshape(Parameters_mzfixed, (gridN, gridN, np.shape(Parameters_mzfixed)[-1]))
# Parameters_yfixed_label = ['a', 'b*x', 'c*x^2', 'd*x^3', 'e*x^4', 'f*x^5']
# for yfixedn in range(np.shape(Parameters_yfixed)[-1]):
#     Parameters_yfixed_n = Parameters_yfixed[..., yfixedn]
#     fig = plt.figure()
#     plt.imshow(Parameters_yfixed_n, cmap='seismic')  # , vmin=-0.5, vmax=0.5)
#     plt.title(Parameters_yfixed_label[yfixedn])
#     plt.colorbar()
#     plt.show()

# No-op placeholder left after the disabled per-(ku, kv) scan above; it has no runtime effect.
pass

#%% Fit one kukv
# Pick one representative (ku, kv) index pair depending on the sign of mz_simed.
# The indices are derived from gridN so they always lie on the grid. For gridN = 41 they
# equal the previously hard-coded 20 (middle index) and 40 (last index).
ku_middle = gridN // 2
if mz_simed > 0:
    kukv = (ku_middle, ku_middle)
elif mz_simed < 0:
    kukv = (0, 0)
else:
    kukv = (ku_middle, gridN - 1)
kukv_pre = kukv
# Grid positions of mz = mz_simed and ky = 0.5. An IndexError here means the value is not a
# grid point to within 1e-10.
mz_show = mz_simed
ky_show = 0.5
mz_show_index = np.where(np.abs(mz_list - mz_show) < 1e-10)[0][0]
ky_show_index = np.where(np.abs(ky_grid - ky_show) < 1e-10)[0][0]
# Slices at the chosen (ku, kv); their axes are (ky, mz). ESZ uses band 0 only.
ESZ_Show_0 = ESigmaZ_all_mz[0, kukv[0], :, kukv[1], :]
ST_Show = ST_all_mz[kukv[0], :, kukv[1], :]
ST_Show_mzfixed = ST_Show[:, mz_show_index]
ST_Show_yfixed = ST_Show[ky_show_index, :]
ESZ_Show_yfixed = ESZ_Show_0[ky_show_index, :]
ESZ_Show_mzfixed = ESZ_Show_0[:, mz_show_index]
# ST_yinted_Show = ST_yinted_all_mz[kukv[0], kukv[1], :]

# plt.figure()
# plt.plot(mz_list[0:50], ESZ_Show_yfixed[0:50], '*')
# plt.xlim(min(mz_list[0:50]), max(mz_list[0:50]))
# plt.ylim(min(ESZ_Show_yfixed[0:50]), max(ESZ_Show_yfixed[0:50]))
# plt.xlabel('mz_list')
# plt.ylabel('Spin')
# plt.title('ESZ_Show_yfixed')
# plt.show()

plt.figure(figsize=(20, 20))
plt.plot(mz_list, ESZ_Show_yfixed, '*')
plt.xlim(min(mz_list), max(mz_list))
plt.ylim(min(ESZ_Show_yfixed), max(ESZ_Show_yfixed))
plt.xlabel('mz_list')
plt.ylabel('Spin')
plt.title('ESZ_Show_yfixed,T=%.3e' % Atom_Temperature)
plt.show()

# S(mz) at ky = 0.5, fitted with a quintic; R^2 measures the fit quality.
Fit_Func_yfixed = poly5
(par_yfixed, pcov_yfixed) = curve_fit(Fit_Func_yfixed, mz_list, ST_Show_yfixed)
perr_yfixed = np.sqrt(np.diag(pcov_yfixed))
ST_Show_yfixed_fit = Fit_Func_yfixed(mz_list, *par_yfixed)
ybar_yfixed = np.sum(ST_Show_yfixed) / len(ST_Show_yfixed)
SS_tot_yfixed = np.sum((ST_Show_yfixed - ybar_yfixed) ** 2)
SS_res_yfixed = np.sum((ST_Show_yfixed - ST_Show_yfixed_fit) ** 2)
R2_yfixed = 1 - SS_res_yfixed / SS_tot_yfixed
plt.figure(figsize=(20, 20))
plt.plot(mz_list, ST_Show_yfixed, '*', markersize=20)
plt.plot(mz_list, ST_Show_yfixed_fit, label='kukv:%s' % str(kukv))
plt.xlim(min(mz_list), max(mz_list))
plt.ylim(min(np.concatenate((ST_Show_yfixed, ST_Show_yfixed_fit))),
         max(np.concatenate((ST_Show_yfixed, ST_Show_yfixed_fit))))
plt.xlabel('mz_list')
plt.ylabel('Spin')
# plt.title("s=%.3f+%.3fx+%.3fx^2+%.3fx^3+%.3fx^4+%.3fx^5" % tuple(par for par in par_yfixed))
plt.title("s=%.3f+%.3fx+%.3fx^2+%.3fx^3+%.3fx^4+%.3fx^5,T=%.3e" % (*par_yfixed, Atom_Temperature))
plt.legend(loc='upper right')
plt.show()

# S(ky) at mz = mz_simed, fitted with a sextic.
Fit_Func_mzfixed = poly6
(par_mzfixed, pcov_mzfixed) = curve_fit(Fit_Func_mzfixed, ky_grid, ST_Show_mzfixed)
perr_mzfixed = np.sqrt(np.diag(pcov_mzfixed))
ST_Show_mzfixed_fit = Fit_Func_mzfixed(ky_grid, *par_mzfixed)

# Fit evaluated on a wider ky range (-1 to 2), kept for interactive inspection.
ky_grid_ext = np.linspace(-1, 2, 10001)
ST_Show_mzfixed_fit_ext = Fit_Func_mzfixed(ky_grid_ext, *par_mzfixed)

ybar_mzfixed = np.sum(ST_Show_mzfixed) / len(ST_Show_mzfixed)
SS_tot_mzfixed = np.sum((ST_Show_mzfixed - ybar_mzfixed) ** 2)
SS_res_mzfixed = np.sum((ST_Show_mzfixed - ST_Show_mzfixed_fit) ** 2)
R2_mzfixed = 1 - SS_res_mzfixed / SS_tot_mzfixed
plt.figure(figsize=(20, 20))
plt.plot(ky_grid, ST_Show_mzfixed, '*', markersize=20)
plt.plot(ky_grid, ST_Show_mzfixed_fit, label='kukv:%s' % str(kukv))
plt.xlim(min(ky_grid), max(ky_grid))
plt.ylim(min(np.concatenate((ST_Show_mzfixed, ST_Show_mzfixed_fit))),
         max(np.concatenate((ST_Show_mzfixed, ST_Show_mzfixed_fit))))
plt.xlabel('kz')
plt.ylabel('Spin')
plt.legend(loc='upper right')
# plt.title("s=%.3f+%.3fx+%.3fx^2+%.3fx^3+%.3fx^4+%.3fx^5+%.3fx^6" % tuple(par for par in par_mzfixed))
plt.title("s=%.3f+%.3fx+%.3fx^2+%.3fx^3+%.3fx^4+%.3fx^5+%.3fx^6,T=%.3e" % (*par_mzfixed, Atom_Temperature))
plt.show()
# Report the fit quality, which was previously computed but never shown.
print('R^2 of S(mz) fit: %.5f, R^2 of S(ky) fit: %.5f' % (R2_yfixed, R2_mzfixed))


# %%
def ky_to_mz(ky, par_y=None, par_mz=None, fit_func_mz=None, imag_tol=1e-8):
    """Return every real mz whose fitted spin S_mz(mz) equals the fitted S_ky(ky).

    S(ky) is evaluated with ``fit_func_mz(ky, *par_mz)``. The equation S_mz(mz) = S(ky) is
    then solved with ``np.roots``. For that reason ``par_y`` must hold ascending-order
    coefficients of a full polynomial, as produced by fitting ``poly5``.

    Parameters:
        ky: The ky value at which S(ky) is evaluated.
        par_y: Coefficients of the S(mz) fit. Defaults to the module-level ``par_yfixed``.
        par_mz: Coefficients of the S(ky) fit. Defaults to the module-level ``par_mzfixed``.
        fit_func_mz: The S(ky) model. Defaults to the module-level ``Fit_Func_mzfixed``.
        imag_tol: Roots whose imaginary part exceeds this in magnitude are discarded. The
            default matches the tolerance in ``mz_to_ky``.

    Pass the fit arguments explicitly to avoid depending on whichever fit last assigned
    the module-level globals.

    Returns:
        A 1-D float array of the real roots. It may be empty or contain several values.
    """
    if par_y is None:
        par_y = par_yfixed
    if par_mz is None:
        par_mz = par_mzfixed
    if fit_func_mz is None:
        fit_func_mz = Fit_Func_mzfixed
    s_ky = fit_func_mz(ky, *par_mz)
    # np.roots wants descending powers; np.array copies, so par_y itself is never modified.
    p = np.array(par_y[::-1], dtype=float)
    p[-1] -= s_ky
    mz_roots = np.roots(p)
    # A tolerance instead of exact == 0 keeps (near-)real roots that pick up round-off.
    return np.real(mz_roots[np.abs(np.imag(mz_roots)) < imag_tol])


def mz_to_ky(mz_, par_y=None, par_mz=None, fit_func_y=None, imag_tol=1e-8):
    """Return the ky value(s) whose fitted spin S_ky(ky) equals the fitted S_mz(mz_).

    S(mz) is evaluated with ``fit_func_y(mz_, *par_y)``. The equation S_ky(ky) = S(mz) is
    then solved with ``np.roots``. For that reason ``par_mz`` must hold ascending-order
    coefficients of a full polynomial, as produced by fitting ``poly6``.

    Parameters:
        mz_: The mz value at which S(mz) is evaluated.
        par_y: Coefficients of the S(mz) fit. Defaults to the module-level ``par_yfixed``.
        par_mz: Coefficients of the S(ky) fit. Defaults to the module-level ``par_mzfixed``.
        fit_func_y: The S(mz) model. Defaults to the module-level ``Fit_Func_yfixed``.
        imag_tol: Roots whose imaginary part exceeds this in magnitude are discarded.

    Pass the fit arguments explicitly to avoid depending on whichever fit last assigned
    the module-level globals.

    Returns:
        - ``array([100])`` (the "no solution" sentinel) when there is no real root.
        - The single real root as a 1-element array when there is exactly one. It is
          returned even if it lies outside [0, 1].
        - Otherwise, the roots that lie in [0, 1], or ``array([100])`` if none do.
    """
    if par_y is None:
        par_y = par_yfixed
    if par_mz is None:
        par_mz = par_mzfixed
    if fit_func_y is None:
        fit_func_y = Fit_Func_yfixed
    s_mz = fit_func_y(mz_, *par_y)
    # np.roots wants descending powers; np.array copies, so par_mz itself is never modified.
    p = np.array(par_mz[::-1], dtype=float)
    p[-1] -= s_mz
    ky_roots = np.roots(p)
    ky_cor = np.real(ky_roots[np.abs(np.imag(ky_roots)) < imag_tol])
    if len(ky_cor) == 0:
        print('what')
        return np.array([100])
    elif len(ky_cor) == 1:
        print('emm')
        return ky_cor
    else:
        ky_good = ky_cor[(ky_cor <= 1) & (ky_cor >= 0)]
        if len(ky_good) == 0:
            ky_good = np.array([100])
        return ky_good


# %% Fit for multiple kukv (new)
# kukv_list = [(0, 0), (20, 0), (20, 20), (0, 20), (10, 10)]
# kukv_list = [(13, 11)]

# Every (ku, kv) index pair on the gridN x gridN grid, with ku varying fastest. This is the
# same ordering the former meshgrid + Fortran-order reshape produced, so kukv_index in the
# saved file names keeps its meaning. Using gridN everywhere, rather than a literal 41 in
# one place and gridN in another, keeps the list consistent with the data.
kukv_list = [(ku, kv) for kv in range(gridN) for ku in range(gridN)]
# ky_measure_list = np.linspace(0, 1, 101)
# mz_measure_list_allkukv = np.zeros((len(ky_measure_list), len(kukv_list)))
# for kukv_index in range(len(kukv_list)):
#     kukv = kukv_list[kukv_index]
#     ST_Show = ST_all_mz[kukv[0], :, kukv[1], :]
#     mz_show = mz_simed
#     ST_Show_mzfixed = ST_Show[:, np.where(np.abs(mz_list - mz_show) < 1e-10)[0][0]]
#     ky_show = 0.5
#     ST_Show_yfixed = ST_Show[np.where(np.abs(ky_grid - ky_show) < 1e-10)[0][0], :]
#     Fit_Func_yfixed = poly5
#     (par_yfixed, pcov_yfixed) = curve_fit(Fit_Func_yfixed, mz_list, ST_Show_yfixed)
#     Fit_Func_mzfixed = poly6
#     (par_mzfixed, pcov_mzfixed) = curve_fit(Fit_Func_mzfixed, ky_grid, ST_Show_mzfixed)
#     for ky_measure_index in range(len(ky_measure_list)):
#         ky_measure = ky_measure_list[ky_measure_index]
#         s_ky_measure = Fit_Func_mzfixed(ky_measure, *par_mzfixed)
#         p_measure = par_yfixed[::-1].copy()
#         p_measure[-1] -= s_ky_measure
#         mz_roots_measure = np.roots(p_measure)
#         mz_measure = np.real(mz_roots_measure[np.where(np.imag(mz_roots_measure) == 0)[0]])
#         mz_measure_list_allkukv[ky_measure_index, kukv_index] = mz_measure
# plt.figure()
# for kukv_index in range(len(kukv_list)):
#     plt.plot(ky_measure_list, mz_measure_list_allkukv[:, kukv_index]*2)
# plt.xlabel('ky')
# plt.ylabel('detuning')
# plt.show()

# Read the measured detunings from the spreadsheet: third column (index 2) of the metadata
# sheet, for the listed 1-based rows. The workbook is opened once and both sheets are taken from it.
# NOTE(review): xlrd >= 2.0 no longer reads .xlsx files; this presumably needs xlrd < 2.0 — confirm.
Spin_Texture_Book = xlrd.open_workbook(Computer_Name + 'Rb87files\\3DSOC_spin_texture\\spin_texture.xlsx')
Data_Table = Spin_Texture_Book.sheets()[0]
Meta_Table = Spin_Texture_Book.sheets()[1]
Rows = np.concatenate((np.arange(33, 71 + 1), [217]))
detuning_measure_list = np.array([Meta_Table.row_values(Row - 1)[2] for Row in Rows])
mz_measure_list = detuning_measure_list / 2
# "No ky found" marker; it is the same value mz_to_ky returns and the plot below filters out.
KY_NO_SOLUTION = 100


def _select_ky(ky_roots_real, mz_value, ambiguous_roots):
    """Reduce the real roots of S_ky(ky) = S_mz(mz_value) to one ky value.

    - A single real root is used as-is.
    - With several roots, those in [0, 1] are kept. If more than one remains, it is recorded
      in ``ambiguous_roots`` and the first is used.
    - Without a usable root, mz <= -0.35 maps to ky = 0 and mz >= -0.05 maps to ky = 1.
    - Anything in between gets KY_NO_SOLUTION. Previously such cases silently reused the
      value from the previous detuning or crashed.
    """
    if len(ky_roots_real) == 1:
        print('emm')
        return ky_roots_real[0]
    if len(ky_roots_real) > 1:
        in_range = ky_roots_real[(ky_roots_real <= 1) & (ky_roots_real >= 0)]
        if len(in_range) > 1:
            ambiguous_roots.append(in_range)
        if len(in_range) >= 1:
            return in_range[0]
    if mz_value <= -0.35:
        return 0.0
    if mz_value >= -0.05:
        return 1.0
    return KY_NO_SOLUTION


# The slice positions are the same for every (ku, kv), so look them up once.
mz_show = mz_simed
ky_show = 0.5
mz_show_index = np.where(np.abs(mz_list - mz_show) < 1e-10)[0][0]
ky_show_index = np.where(np.abs(ky_grid - ky_show) < 1e-10)[0][0]
ky_good_what = []
ky_measure_list_allkukv = np.zeros((len(mz_measure_list), len(kukv_list)))
for kukv_index, kukv_k in enumerate(kukv_list):
    ST_k = ST_all_mz[kukv_k[0], :, kukv_k[1], :]
    # Use scan-specific names: par_yfixed / par_mzfixed are read by mz_to_ky / ky_to_mz and
    # must keep the single-(ku, kv) fit made earlier in the script. Reusing them here
    # previously left those functions working with the fit of the last (ku, kv) in the scan.
    (par_y_k, _) = curve_fit(poly5, mz_list, ST_k[ky_show_index, :])
    (par_mz_k, _) = curve_fit(poly6, ky_grid, ST_k[:, mz_show_index])
    p_mz_k = par_mz_k[::-1]  # descending powers for np.roots
    for mz_measure_index, mz_measure in enumerate(mz_measure_list):
        s_mz = poly5(mz_measure, *par_y_k)
        p = p_mz_k.copy()
        p[-1] -= s_mz
        ky_roots = np.roots(p)
        ky_cor = np.real(ky_roots[np.abs(np.imag(ky_roots)) < 1e-8])
        ky_measure_list_allkukv[mz_measure_index, kukv_index] = _select_ky(ky_cor, mz_measure, ky_good_what)

    detuning_ky_pair = np.zeros((len(detuning_measure_list), 3))
    detuning_ky_pair[:, 0] = detuning_measure_list
    detuning_ky_pair[:, 1] = ky_measure_list_allkukv[:, kukv_index]
    detuning_ky_pair[:, 2] = Rows
    sio.savemat(Computer_Name + "Rb87Data\\3DSOCbandData\\190827\\detuning_ky_pair,mz_sim=%.3f_%dT%.3e.mat" %
                (mz_show, kukv_index, Atom_Temperature),
                {'detuning_ky_pair': detuning_ky_pair, 'Try': 5})

# Ambiguous root sets can differ in length. np.array() rejects ragged input on recent NumPy,
# so ragged sets are stored as an object array, which savemat writes as a cell array.
if len({len(roots) for roots in ky_good_what}) > 1:
    ragged_roots = np.empty(len(ky_good_what), dtype=object)
    for roots_index, roots in enumerate(ky_good_what):
        ragged_roots[roots_index] = roots
    ky_good_what = ragged_roots
else:
    ky_good_what = np.array(ky_good_what)
sio.savemat(Computer_Name + "Rb87Data\\3DSOCbandData\\190827\\detuning_ky_pairs_all,mz_sim=%.3fT%.3e.mat" %
            (mz_show, Atom_Temperature),
            {'mz_measure_list': mz_measure_list, 'ky_measure_list_allkukv': ky_measure_list_allkukv,
             'Try': 5, 'ky_good_what': ky_good_what})
fig = plt.figure(figsize=(20, 20))
for kukv_index in range(len(kukv_list)):
    # A boolean mask keeps ky_see 1-D even when only one detuning has a solution.
    ky_column = ky_measure_list_allkukv[:, kukv_index]
    use_l = ky_column != KY_NO_SOLUTION
    mz_see = mz_measure_list[use_l]
    ky_see = ky_column[use_l]
    mz_sort = np.argsort(mz_see)
    plt.plot(mz_see[mz_sort], ky_see[mz_sort], '*-k', markersize=20)  # , label=str(kukv_list[kukv_index]))
plt.xlabel('mz')
plt.ylabel('ky')
plt.ylim([0, 1])
plt.xlim([np.min(mz_measure_list), np.max(mz_measure_list)])
plt.xticks(np.linspace(np.min(mz_measure_list), np.max(mz_measure_list), 11))
plt.yticks(np.linspace(0, 1, 11))
plt.title('T=%.3e' % Atom_Temperature)
# plt.legend(loc='upper right')
plt.show()

#%%
# checks = np.zeros((len(mz_measure_list), 2))
# checks[:, 0] = mz_measure_list
# checks[:, 1] = np.squeeze(ky_measure_list_allkukv)


# def forceAspect(ax_, aspect=1):
#     im = ax_.get_images()
#     extent = im[0].get_extent()
#     ax_.set_aspect(abs((extent[1]-extent[0])/(extent[3]-extent[2]))/aspect)
#
#
# fig = plt.figure(figsize=(10, 10))
# ax = fig.add_subplot(111)
# plt.imshow(ky_good_what)
# forceAspect(ax, aspect=1)
# plt.colorbar()
# plt.show()
# fig = plt.figure(figsize=(10, 10))
# plt.plot(ky_good_what[:, 1] / ky_good_what[:, 0])
# plt.show()

# Overlay the ky(mz) curves of every (ku, kv), one colour per atom temperature, using the
# summary .mat files written by the multi-(ku, kv) scan.
T_list = np.array([50e-9, 100e-9, 150e-9, 200e-9])
colors = ['b', 'r', 'y', 'g']
fig = plt.figure(figsize=(20, 20))
for T_index, T_value in enumerate(T_list):
    # A separate loop name leaves the global Atom_Temperature unchanged.
    Load_Name = Computer_Name + "Rb87Data\\3DSOCbandData\\190827\\detuning_ky_pairs_all,mz_sim=%.3fT%.3e.mat" % \
        (mz_show, T_value)
    if not os.path.isfile(Load_Name):
        # Skip temperatures whose scan has not been run yet, instead of aborting the whole plot.
        print('Missing %s, skipping T=%.3eK' % (Load_Name, T_value))
        continue
    data = sio.loadmat(Load_Name)
    ky_measure_list_allkukv = data['ky_measure_list_allkukv']
    for kukv_index in range(len(kukv_list)):
        ky_column = ky_measure_list_allkukv[:, kukv_index]
        use_l = ky_column != 100  # 100 marks "no ky solution"
        mz_see = mz_measure_list[use_l]
        ky_see = ky_column[use_l]
        mz_sort = np.argsort(mz_see)
        # Only the first curve per temperature gets a legend entry. The format string carries
        # no colour, so color= is not contradicted by a format colour.
        curve_label = '%.3eK' % T_value if kukv_index == 0 else '_nolegend_'
        plt.plot(mz_see[mz_sort], ky_see[mz_sort], '-', color=colors[T_index], label=curve_label)
plt.xlabel('mz')
plt.ylabel('ky')
plt.ylim([0, 1])
plt.xlim([np.min(mz_measure_list), np.max(mz_measure_list)])
plt.xticks(np.linspace(np.min(mz_measure_list), np.max(mz_measure_list), 11))
plt.yticks(np.linspace(0, 1, 11))
# plt.title('T=%.3e' % Atom_Temperature)
plt.legend(loc='upper left')
plt.show()
# %%
# Convert the measured detunings of the selected spreadsheet rows to ky using mz_to_ky, then
# save (detuning, ky, row) triples as .txt and .mat. The workbook is opened once for both sheets.
# NOTE(review): xlrd >= 2.0 no longer reads .xlsx files; this presumably needs xlrd < 2.0 — confirm.
Spin_Texture_Book = xlrd.open_workbook(Computer_Name + 'Rb87files\\3DSOC_spin_texture\\spin_texture.xlsx')
Data_Table = Spin_Texture_Book.sheets()[0]
Meta_Table = Spin_Texture_Book.sheets()[1]
# detuning_measure_list = np.array([-0.02, -0.04, -0.06, -0.08, -0.1, -0.12, -0.14, -0.16, -0.18, -0.2, -0.22, -0.24,
#                                   -0.26, -0.28, -0.3, -0.32, -0.34, -0.36, -0.38, -0.4, -0.42, -0.44, -0.46, -0.48,
#                                   -0.5, -0.52, -0.54, -0.56, -0.58, -0.6, -0.05, -0.15, -0.25, -0.35, -0.45, -0.55,
#                                   -0.65, -0.7, -0.75, -0.8, -0.85, -0.9, -0.95, -1, -0.782])
# Rows = np.concatenate((np.arange(2, 76 + 1), [217]))
# Rows = np.arange(202, 222 + 1)

# mz_simed = 0.25
Rows = np.concatenate((np.arange(185, 192 + 1), np.arange(156, 168 + 1)))
detuning_measure_list = np.array([Meta_Table.row_values(Row - 1)[2] for Row in Rows])
Try = 5
#
mz_measure_list = detuning_measure_list / 2
ky_measure_list_caled = []
for mz_measure in mz_measure_list:
    ky_candidates = mz_to_ky(mz_measure)
    if len(ky_candidates) > 1:
        # Several roots in [0, 1] would make the array ragged. Keep the first one, as the
        # multi-(ku, kv) scan does, but report it.
        print('mz=%.3f: several ky roots %s, keeping the first' % (mz_measure, ky_candidates))
    ky_measure_list_caled.append(ky_candidates[0])
ky_measure_list_caled = np.array(ky_measure_list_caled)
detuning_ky_pair = np.zeros((len(detuning_measure_list), 3))
detuning_ky_pair[:, 0] = detuning_measure_list
detuning_ky_pair[:, 1] = ky_measure_list_caled
detuning_ky_pair[:, 2] = Rows
np.savetxt(Computer_Name + "Rb87Data\\3DSOCbandData\\200117\\detuning_ky_pair,mz_sim=%.3f.txt" % mz_show,
           detuning_ky_pair, fmt='%.3f', header='detuning ky row')
sio.savemat(Computer_Name + "Rb87Data\\3DSOCbandData\\200117\\detuning_ky_pair,mz_sim=%.3f.mat" % mz_show,
            {'detuning_ky_pair': detuning_ky_pair, 'Try': Try})

# %%
# Inverse table: for evenly spaced ky, find the detuning (2 * mz) whose fitted spin matches,
# print it, and save (ky, detuning) pairs as .txt and .mat.
ky_measure_list = np.linspace(0, 1, 21)
# ky_measure_list = np.array([0, 0.1, 0.2, 0.3, 0.39, 0.5, 0.61, 0.7, 0.8, 0.9, 1])
# The S(mz) polynomial was fitted on mz_list; roots outside that range are extrapolation.
mz_fit_min = np.min(mz_list)
mz_fit_max = np.max(mz_list)
mz_measure_list_caled = []
for ky_measure in ky_measure_list:
    mz_candidates = ky_to_mz(ky_measure)
    if len(mz_candidates) > 1:
        # Several real roots would make the array ragged. Prefer those inside the fitted range.
        in_fit_range = mz_candidates[(mz_candidates >= mz_fit_min) & (mz_candidates <= mz_fit_max)]
        if len(in_fit_range) > 0:
            mz_candidates = in_fit_range
        if len(mz_candidates) > 1:
            print('ky=%.3f: several mz roots %s, keeping the first' % (ky_measure, mz_candidates))
    if len(mz_candidates) == 0:
        print('ky=%.3f: no real mz root, storing nan' % ky_measure)
        mz_measure_list_caled.append(np.nan)
    else:
        mz_measure_list_caled.append(mz_candidates[0])

mz_measure_list_caled = np.array(mz_measure_list_caled)
detuning_measure_list_caled = mz_measure_list_caled * 2
for detuning_measure in detuning_measure_list_caled:
    print('%.2f' % detuning_measure)

ky_detuning_pair = np.zeros((len(ky_measure_list), 2))
ky_detuning_pair[:, 0] = ky_measure_list
ky_detuning_pair[:, 1] = detuning_measure_list_caled
np.savetxt(Computer_Name + "Rb87Data\\3DSOCbandData\\190827\\ky_detuning_pair,mz_sim=%.3f.txt" % mz_show,
           ky_detuning_pair, fmt='%.3f', header='ky detuning')
sio.savemat(Computer_Name + "Rb87Data\\3DSOCbandData\\190827\\ky_detuning_pair,mz_sim=%.3f.mat" % mz_show,
            {'ky_detuning_pair': ky_detuning_pair})

# # %%
# End_Time = time.time()
# print('Start: ', Start_Time)
# print('End: ', End_Time)
# print('Total: ', End_Time - Start_Time)
